
########################################################################################################################
# Imports
########################################################################################################################

import tensorflow as tf
import sys

sys.path.append("../")
########################################################################################################################
# UNet
########################################################################################################################
from Architectures.srmodels import resolve_single
from Architectures.srmodels.edsr import edsr
class EDSRNet:

    ########################################################################################################################
    # downsample
    ########################################################################################################################

    

    ########################################################################################################################
    # __init__
    ########################################################################################################################

    def __init__(self, inputChannels, outputChannels, inputResolutionU, inputResolutionV,scale= 1,num_blocks=4):

        self.outputChannels = outputChannels

       
        self.model = edsr(input_channels=inputChannels,output_channels=outputChannels,scale=scale, num_res_blocks=num_blocks)



        print(self.model.summary())

    ########################################################################################################################
    # Additional backbone
    ########################################################################################################################


if __name__=="__main__":
    SRNet = EDSRNet(75, 3, 1028,752,scale=4)
    import time
    for i in range(10):

        temp=tf.zeros((1,752,1028,75))
        a=time.time()
        c=SRNet.model(temp)
        #
       # print(c[0])
        b=time.time()
        #y=c.numpy()
        print(a-b)